/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/* !
 * \file kernel_micro_datacopy_impl.h
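 * \brief Store-direction MicroAPI DataCopy implementations: vsts / vst / vsstb stores and the
 *        vstus / vstu / vstur unaligned stores (plus their vstas / vsta / vstar counterparts),
 *        writing vector registers to __local_mem__ buffers.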
 */
#ifndef ASCENDC_MODULE_MICRO_DATACOPY_STORE_IMPL_H
#define ASCENDC_MODULE_MICRO_DATACOPY_STORE_IMPL_H

namespace AscendC {
namespace MicroAPI {
/**
 * \brief Compile-time validation of the StoreDist requested by a DataCopy store.
 * \tparam InputNum number of source registers of the calling store: 1 for the single-register
 *         overloads, 2 for the dual-register overloads. Any other value is rejected.
 * \tparam dist     requested store distribution.
 * Single-register stores accept the NORM / FIRST_ELEMENT / PACK distributions; dual-register stores
 * accept only the INTLV distributions.
 */
template <int InputNum, StoreDist dist> __simd_callee__ inline void CheckStoreDist()
{
    // Without this guard any InputNum other than 1 would silently be checked against the dual-register set.
    static_assert(InputNum == 1 || InputNum == 2, "CheckStoreDist only support InputNum 1 or 2");
    if constexpr (InputNum == 1) {
        static_assert(SupportEnum<dist, StoreDist::DIST_NORM_B8, StoreDist::DIST_NORM_B16, StoreDist::DIST_NORM_B32,
            StoreDist::DIST_FIRST_ELEMENT_B8, StoreDist::DIST_FIRST_ELEMENT_B16, StoreDist::DIST_FIRST_ELEMENT_B32,
            StoreDist::DIST_PACK_B16, StoreDist::DIST_PACK_B32, StoreDist::DIST_PACK_B64, StoreDist::DIST_PACK4_B32,
            StoreDist::DIST_NORM>(),
            "DataCopy not support this dist on current device");
    } else {
        static_assert(
            SupportEnum<dist, StoreDist::DIST_INTLV_B8, StoreDist::DIST_INTLV_B16, StoreDist::DIST_INTLV_B32>(),
            "DataCopy not support this dist on current device");
    }
}

// vsts
/**
 * \brief Store one vector register to local memory with vsts.
 * \tparam T    destination element type, deduced from dstUbAddr; must be DefaultType or RegT::ActualT.
 * \tparam dist store distribution, validated by CheckStoreDist<1, dist>.
 * \tparam RegT register type; only the RegTraitNumOne and RegTraitNumTwo layouts are accepted.
 * \param dstUbAddr destination address.
 * \param srcReg    source register.
 * \param mask      store predicate (re-derived for the b64 and two-register complex32 paths).
 */
template <typename T = DefaultType, StoreDist dist = StoreDist::DIST_NORM, typename RegT>
__simd_callee__ inline void DataCopyImpl(__local_mem__ T *dstUbAddr, RegT &srcReg, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(CheckRegTrait<RegT, RegTraitNumOne>() || CheckRegTrait<RegT, RegTraitNumTwo>(),
        "RegTensor only support RegTraitNumOne or RegTraitNumTwo on current device!");
    CheckStoreDist<1, dist>();
    constexpr auto distValue = std::integral_constant<::DistVST, static_cast<::DistVST>(GetStoreDist<T, dist>())>();
    if constexpr (SupportType<ActualT, fp4x2_e2m1_t, fp4x2_e1m2_t, hifloat8_t, fp8_e5m2_t, fp8_e4m3fn_t,
        int4x2_t>()) {
        // Packed 4-bit / 8-bit float and int types are stored through a uint8_t view.
        vsts((RegTensor<uint8_t> &)srcReg, (__ubuf__ uint8_t *)dstUbAddr, 0, distValue, mask);
    } else if constexpr (SupportBytes<ActualT, 8>()) {
        // b64 data is written with b32 stores; the predicate is reshaped with pintlv_b32 accordingly.
        MaskReg dstMask0;
        MaskReg dstMask1;
        if constexpr (CheckRegTrait<RegT, RegTraitNumOne>()) {
            constexpr auto lowerDist =
                std::integral_constant<::HiloPart, static_cast<::HiloPart>(HighLowPart::LOWEST)>();
            MaskReg tmpMask;
            ppack(tmpMask, mask, lowerDist);
            pintlv_b32(dstMask0, dstMask1, tmpMask, tmpMask);
            // Single register: only the first interleaved predicate is needed.
            vsts((RegTensor<uint32_t> &)srcReg, (__ubuf__ uint32_t *)dstUbAddr, 0, distValue, dstMask0);
        } else if constexpr (CheckRegTrait<RegT, RegTraitNumTwo>()) {
            RegTensor<uint32_t> reg0;
            RegTensor<uint32_t> reg1;
            pintlv_b32(dstMask0, dstMask1, mask, mask);
            // Interleave the two halves held in reg[0]/reg[1] into two b32 registers.
            Interleave(reg0, reg1, (RegTensor<uint32_t> &)srcReg.reg[0], (RegTensor<uint32_t> &)srcReg.reg[1]);
            vsts((RegTensor<uint32_t> &)reg0, (__local_mem__ uint32_t *)dstUbAddr, 0, distValue, dstMask0);
            // Second register goes VECTOR_REG_WIDTH / sizeof(uint32_t) uint32_t elements after the first.
            vsts((RegTensor<uint32_t> &)reg1, (__local_mem__ uint32_t *)dstUbAddr, VECTOR_REG_WIDTH / sizeof(uint32_t),
                distValue, dstMask1);
        }
    } else {
        if constexpr(SupportType<ActualT, complex32>() && (CheckRegTrait<RegT, RegTraitNumTwo>())) {
            // Two-register complex32: same interleave-and-store-twice scheme, through uint16_t views.
            MaskReg dstMask0;
            MaskReg dstMask1;
            RegTensor<uint16_t> reg0;
            RegTensor<uint16_t> reg1;
            pintlv_b16(dstMask0, dstMask1, mask, mask);
            Interleave(reg0, reg1, (RegTensor<uint16_t> &)srcReg.reg[0], (RegTensor<uint16_t> &)srcReg.reg[1]);
            vsts((RegTensor<uint16_t> &)reg0, (__local_mem__ uint16_t *)dstUbAddr, 0, distValue, dstMask0);
            vsts((RegTensor<uint16_t> &)reg1, (__local_mem__ uint16_t *)dstUbAddr, VECTOR_REG_WIDTH / sizeof(uint16_t),
                distValue, dstMask1);
        } else {
            static_assert(SupportBytes<ActualT, 1, 2, 4, 8>(),
                "DataCopy only support type b8/b16/b32/b64 on current device");
            if constexpr (std::is_same_v<T, bool>) {
                vsts((RegTensor<int8_t> &)srcReg, (__ubuf__ int8_t *)dstUbAddr, 0, distValue, mask);
            } else if constexpr (SupportBytes<ActualT, 4>()) {
                // All 4-byte types are stored through an int32_t view.
                vsts((RegTensor<int32_t> &)srcReg, (__ubuf__ int32_t *)dstUbAddr, 0, distValue, mask);
            } else {
                vsts(srcReg, dstUbAddr, 0, distValue, mask);
            }
        }
    }
}

// vsts postupdate
/**
 * \brief Store one vector register with the post-update form of vsts selected by postMode.
 * \tparam T        destination element type, deduced from dstUbAddr; must be DefaultType or RegT::ActualT.
 * \tparam postMode forwarded to vsts as a ::Post literal.
 * \tparam dist     store distribution, validated by CheckStoreDist<1, dist>.
 * \param dstUbAddr        destination pointer, taken by reference so vsts can update it.
 * \param srcReg           source register (RegTraitNumOne or RegTraitNumTwo layout).
 * \param postUpdateStride stride forwarded to vsts; presumably counted in ActualT elements, since the
 *                         b64 / complex32 paths (stored through 32-bit / 16-bit views) scale it by 2.
 * \param mask             store predicate.
 */
template <typename T = DefaultType, PostLiteral postMode, StoreDist dist = StoreDist::DIST_NORM, typename RegT>
__simd_callee__ inline void DataCopyImpl(__local_mem__ T *&dstUbAddr, RegT &srcReg, int32_t postUpdateStride, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    // Same layout restriction as the non-post-update vsts overload. Without it, a b64 RegT with any other
    // trait would match neither branch of the b64 path below and silently store nothing.
    static_assert(CheckRegTrait<RegT, RegTraitNumOne>() || CheckRegTrait<RegT, RegTraitNumTwo>(),
        "RegTensor only support RegTraitNumOne or RegTraitNumTwo on current device!");
    CheckStoreDist<1, dist>();
    constexpr auto distValue = std::integral_constant<::DistVST, static_cast<::DistVST>(GetStoreDist<T, dist>())>();
    constexpr auto postValue = std::integral_constant<::Post, static_cast<::Post>(postMode)>();
    if constexpr (SupportType<ActualT, fp4x2_e2m1_t, fp4x2_e1m2_t, hifloat8_t, fp8_e5m2_t, fp8_e4m3fn_t, int4x2_t>()) {
        vsts((RegTensor<uint8_t> &)srcReg, (__ubuf__ uint8_t *&)dstUbAddr, postUpdateStride, distValue, mask, postValue);
    } else if constexpr (SupportBytes<ActualT, 8>()) {
        // b64 data is written with b32 stores, so strides are doubled.
        MaskReg dstMask0;
        MaskReg dstMask1;
        if constexpr (CheckRegTrait<RegT, RegTraitNumOne>()) {
            constexpr auto lowerDist = std::integral_constant<::HiloPart, static_cast<::HiloPart>(HighLowPart::LOWEST)>();
            MaskReg tmpMask;
            ppack(tmpMask, mask, lowerDist);
            pintlv_b32(dstMask0, dstMask1, tmpMask, tmpMask);
            vsts((RegTensor<uint32_t> &)srcReg, (__ubuf__ uint32_t *&)dstUbAddr, postUpdateStride * 2, distValue,
                dstMask0, postValue);
        } else if constexpr (CheckRegTrait<RegT, RegTraitNumTwo>()) {
            RegTensor<uint32_t> reg0;
            RegTensor<uint32_t> reg1;
            pintlv_b32(dstMask0, dstMask1, mask, mask);
            Interleave(reg0, reg1, (RegTensor<uint32_t> &)srcReg.reg[0], (RegTensor<uint32_t> &)srcReg.reg[1]);
            // Split the stride over the two stores: the first advances by at most one register's worth of
            // ActualT elements, the second by the remainder.
            // NOTE(review): for 0 <= postUpdateStride < one_repeat_num the second store is issued at
            // dstUbAddr + postUpdateStride, presumably overlapping the first register's data; confirm intended.
            constexpr uint32_t one_repeat_num = VECTOR_REG_WIDTH / sizeof(ActualT);
            uint32_t tmpStride1 = (postUpdateStride > one_repeat_num) ? one_repeat_num : postUpdateStride;
            uint32_t tmpStride2 = (postUpdateStride > one_repeat_num) ? postUpdateStride - one_repeat_num : 0;
            vsts((RegTensor<uint32_t> &)reg0, (__local_mem__ uint32_t *&)dstUbAddr, tmpStride1 * 2, distValue, dstMask0,
                postValue);
            vsts((RegTensor<uint32_t> &)reg1, (__local_mem__ uint32_t *&)dstUbAddr, tmpStride2 * 2, distValue, dstMask1,
                postValue);
        }
    } else {
        if constexpr(SupportType<ActualT, complex32>() && (CheckRegTrait<RegT, RegTraitNumTwo>())) {
            // Two-register complex32: interleave through uint16_t views and split the stride as above.
            MaskReg dstMask0;
            MaskReg dstMask1;
            RegTensor<uint16_t> reg0;
            RegTensor<uint16_t> reg1;
            pintlv_b16(dstMask0, dstMask1, mask, mask);
            Interleave(reg0, reg1, (RegTensor<uint16_t> &)srcReg.reg[0], (RegTensor<uint16_t> &)srcReg.reg[1]);
            constexpr uint32_t one_repeat_num = VECTOR_REG_WIDTH / sizeof(ActualT);
            uint32_t tmpStride1 = (postUpdateStride > one_repeat_num) ? one_repeat_num : postUpdateStride;
            uint32_t tmpStride2 = (postUpdateStride > one_repeat_num) ? postUpdateStride - one_repeat_num : 0;
            vsts((RegTensor<uint16_t> &)reg0, (__local_mem__ uint16_t *&)dstUbAddr, tmpStride1 * 2, distValue, dstMask0,
                postValue);
            vsts((RegTensor<uint16_t> &)reg1, (__local_mem__ uint16_t *&)dstUbAddr, tmpStride2 * 2, distValue, dstMask1,
                postValue);
        } else {
            static_assert(SupportBytes<ActualT, 1, 2, 4, 8>(),
                "DataCopy only support type b8/b16/b32/b64 on current device");
            if constexpr (std::is_same_v<T, bool>) {
                vsts((RegTensor<int8_t> &)srcReg, (__ubuf__ int8_t *&)dstUbAddr, postUpdateStride, distValue, mask, postValue);
            } else if constexpr (SupportBytes<ActualT, 4>()) {
                vsts((RegTensor<int32_t> &)srcReg, (__ubuf__ int32_t *&)dstUbAddr, postUpdateStride, distValue, mask, postValue);
            } else {
                vsts(srcReg, dstUbAddr, postUpdateStride, distValue, mask, postValue);
            }
        }
    }
}

// vst areg
/**
 * \brief Store one vector register with vst, addressing dstUbAddr through the address register offset.
 * \tparam T    destination element type, deduced from dstUbAddr; must be DefaultType or RegT::ActualT.
 * \tparam dist store distribution, validated by CheckStoreDist<1, dist>.
 * \tparam RegT register type; only the RegTraitNumOne layout is accepted.
 * \param dstUbAddr destination base address.
 * \param srcReg    source register.
 * \param offset    address register forwarded to vst.
 * \param mask      store predicate (re-derived for b64).
 */
template <typename T = DefaultType, StoreDist dist = StoreDist::DIST_NORM, typename RegT>
__simd_callee__ inline void DataCopyImpl(__local_mem__ T *dstUbAddr, RegT &srcReg, AddrReg offset, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(CheckRegTrait<RegT, RegTraitNumOne>(), "RegTensor only support RegTraitNumOne on current device!");
    CheckStoreDist<1, dist>();
    constexpr auto distValue = std::integral_constant<::DistVST, static_cast<::DistVST>(GetStoreDist<T, dist>())>();
    if constexpr (SupportType<ActualT, fp4x2_e2m1_t, fp4x2_e1m2_t, hifloat8_t, fp8_e5m2_t, fp8_e4m3fn_t,
        int4x2_t>()) {
        vst((RegTensor<uint8_t> &)srcReg, (__ubuf__ uint8_t *)dstUbAddr, offset, distValue, mask);
    } else {
        static_assert(SupportBytes<ActualT, 1, 2, 4, 8>(),
            "DataCopy only support type b8/b16/b32/b64 on current device");
        if constexpr (std::is_same_v<T, bool>) {
            vst((RegTensor<int8_t> &)srcReg, (__ubuf__ int8_t *)dstUbAddr, offset, distValue, mask);
        } else if constexpr (SupportBytes<ActualT, 4>()) {
            vst((RegTensor<int32_t> &)srcReg, (__ubuf__ int32_t *)dstUbAddr, offset, distValue, mask);
        } else if constexpr (SupportBytes<ActualT, 8>()) {
            // using b32 vst to simulate b64 vst: MaskPack + pintlv_b32 derive the b32 predicate;
            // the second pintlv_b32 output (emptyMask) is not used.
            // NOTE(review): offset is forwarded unchanged to the b32 view; confirm its unit needs no rescaling.
            MaskReg tmpMask;
            MaskReg emptyMask;
            MaskPack(tmpMask, mask);
            pintlv_b32(tmpMask, emptyMask, tmpMask, tmpMask);
            vst((RegTensor<int32_t> &)srcReg, (__ubuf__ int32_t *)dstUbAddr, offset, distValue, tmpMask);
        } else {
            vst(srcReg, dstUbAddr, offset, distValue, mask);
        }
    }
}

// vsts dual
/**
 * \brief Store two vector registers with the dual-source form of vsts.
 * \tparam T    destination element type, deduced from dstUbAddr; must be DefaultType or RegT::ActualT.
 * \tparam dist store distribution; must be one of the INTLV distributions (CheckStoreDist<2, dist>).
 * \tparam RegT register type; only the RegTraitNumOne layout is accepted.
 * \param dstUbAddr destination address.
 * \param srcReg0   first source register.
 * \param srcReg1   second source register.
 * \param mask      store predicate (re-derived for b64).
 */
template <typename T = DefaultType, StoreDist dist, typename RegT>
__simd_callee__ inline void DataCopyImpl(__local_mem__ T *dstUbAddr, RegT &srcReg0, RegT &srcReg1, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(CheckRegTrait<RegT, RegTraitNumOne>(), "RegTensor only support RegTraitNumOne on current device!");
    CheckStoreDist<2, dist>();
    constexpr auto distValue = std::integral_constant<::DistVST, static_cast<::DistVST>(GetStoreDist<T, dist>())>();
    if constexpr (SupportType<ActualT, fp4x2_e2m1_t, fp4x2_e1m2_t, hifloat8_t, fp8_e5m2_t, fp8_e4m3fn_t,
        int4x2_t>()) {
        vsts((RegTensor<uint8_t> &)srcReg0, (RegTensor<uint8_t> &)srcReg1, (__ubuf__ uint8_t *)dstUbAddr, 0, distValue,
            mask);
    } else if constexpr (SupportType<ActualT, float>()) { // ccec no float signature
        vsts((RegTensor<uint32_t> &)srcReg0, (RegTensor<uint32_t> &)srcReg1, (__ubuf__ uint32_t *)dstUbAddr, 0,
            distValue, mask);
    } else {
        static_assert(SupportBytes<ActualT, 1, 2, 4, 8>(), "DataCopy only support type b8/b16/b32/b64 on current device");
        if constexpr (std::is_same_v<T, bool>) {
            vsts((RegTensor<int8_t> &)srcReg0, (RegTensor<int8_t> &)srcReg1, (__ubuf__ int8_t *)dstUbAddr, 0, distValue, mask);
        } else if constexpr (SupportBytes<ActualT, 4>()) {
            vsts(
                (RegTensor<int32_t> &)srcReg0, (RegTensor<int32_t> &)srcReg1, (__ubuf__ int32_t *)dstUbAddr, 0, distValue, mask);
        } else if constexpr (SupportBytes<ActualT, 8>()) {
            // using b32 vst to simulate b64 vst: MaskPack + pintlv_b32 derive the b32 predicate;
            // the second pintlv_b32 output (emptyMask) is not used.
            MaskReg tmpMask;
            MaskReg emptyMask;
            MaskPack(tmpMask, mask);
            pintlv_b32(tmpMask, emptyMask, tmpMask, tmpMask);
            vsts(
                (RegTensor<int32_t> &)srcReg0, (RegTensor<int32_t> &)srcReg1, (__ubuf__ int32_t *)dstUbAddr, 0, distValue, tmpMask);
        } else {
            vsts(srcReg0, srcReg1, dstUbAddr, 0, distValue, mask);
        }
    }
}

// vsts dual areg
/**
 * \brief Store two vector registers with the dual-source form of vst, addressed through an address register.
 * \tparam T    destination element type, deduced from dstUbAddr; must be DefaultType or RegT::ActualT.
 * \tparam dist store distribution; must be one of the INTLV distributions (CheckStoreDist<2, dist>).
 * \tparam RegT register type; only the RegTraitNumOne layout is accepted.
 * \param dstUbAddr destination base address.
 * \param srcReg0   first source register.
 * \param srcReg1   second source register.
 * \param offset    address register forwarded to vst.
 * \param mask      store predicate (re-derived for b64).
 */
template <typename T = DefaultType, StoreDist dist, typename RegT>
__simd_callee__ inline void DataCopyImpl(__local_mem__ T *dstUbAddr, RegT &srcReg0, RegT &srcReg1, AddrReg offset,
    MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(CheckRegTrait<RegT, RegTraitNumOne>(), "RegTensor only support RegTraitNumOne on current device!");
    CheckStoreDist<2, dist>();
    constexpr auto distValue = std::integral_constant<::DistVST, static_cast<::DistVST>(GetStoreDist<T, dist>())>();
    if constexpr (SupportType<ActualT, fp4x2_e2m1_t, fp4x2_e1m2_t, hifloat8_t, fp8_e5m2_t, fp8_e4m3fn_t,
        int4x2_t>()) {
        vst((RegTensor<uint8_t> &)srcReg0, (RegTensor<uint8_t> &)srcReg1, (__ubuf__ uint8_t *)dstUbAddr, offset,
            distValue, mask);
    } else if constexpr (SupportType<ActualT, float>()) { // ccec no float signature
        vst((RegTensor<uint32_t> &)srcReg0, (RegTensor<uint32_t> &)srcReg1, (__ubuf__ uint32_t *)dstUbAddr, offset,
            distValue, mask);
    } else {
        static_assert(SupportBytes<ActualT, 1, 2, 4, 8>(), "DataCopy only support type b8/b16/b32/b64 on current device");
        if constexpr (std::is_same_v<T, bool>) {
            vst((RegTensor<int8_t> &)srcReg0, (RegTensor<int8_t> &)srcReg1, (__ubuf__ int8_t *)dstUbAddr,
                offset, distValue, mask);
        } else if constexpr (SupportBytes<ActualT, 4>()) {
            vst((RegTensor<int32_t> &)srcReg0, (RegTensor<int32_t> &)srcReg1, (__ubuf__ int32_t *)dstUbAddr,
                offset, distValue, mask);
        } else if constexpr (SupportBytes<ActualT, 8>()) {
            // using b32 vst to simulate b64 vst: MaskPack + pintlv_b32 derive the b32 predicate;
            // the second pintlv_b32 output (emptyMask) is not used.
            MaskReg tmpMask;
            MaskReg emptyMask;
            MaskPack(tmpMask, mask);
            pintlv_b32(tmpMask, emptyMask, tmpMask, tmpMask);
            vst((RegTensor<int32_t> &)srcReg0, (RegTensor<int32_t> &)srcReg1, (__ubuf__ int32_t *)dstUbAddr,
                offset, distValue, tmpMask);
        } else {
            vst(srcReg0, srcReg1, dstUbAddr, offset, distValue, mask);
        }
    }
}

// vsstb
/**
 * \brief Block-strided store of one vector register with vsstb.
 * \tparam T        destination element type, deduced from dstUbAddr; must be DefaultType or RegT::ActualT.
 * \tparam dataMode not referenced by this implementation.
 * \tparam RegT     register type; only the RegTraitNumOne layout is accepted.
 * \param dstUbAddr       destination address.
 * \param srcReg          source register.
 * \param dataBlockStride placed in the upper 16 bits of the vsstb stride operand.
 * \param mask            store predicate.
 */
template <typename T = DefaultType, DataCopyMode dataMode, typename RegT>
__simd_callee__ inline void DataCopyImpl(__local_mem__ T *dstUbAddr, RegT &srcReg, uint32_t dataBlockStride, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(CheckRegTrait<RegT, RegTraitNumOne>(), "RegTensor only support RegTraitNumOne on current device!");
    // Stride operand: data block stride in bits [31:16]; bits [15:0] stay zero in this non-post-update form.
    const uint32_t strideConfig = dataBlockStride << 16u;
    if constexpr (SupportType<ActualT, fp4x2_e2m1_t, fp4x2_e1m2_t, hifloat8_t, fp8_e5m2_t, fp8_e4m3fn_t,
        int4x2_t>()) {
        vsstb((RegTensor<uint8_t> &)srcReg, (__ubuf__ uint8_t *)dstUbAddr, strideConfig, mask);
    } else {
        static_assert(SupportBytes<ActualT, 1, 2, 4>(), "DataCopy only support type b8/b16/b32 on current device");
        if constexpr (std::is_same_v<T, bool>) {
            vsstb((RegTensor<int8_t> &)srcReg, (__ubuf__ int8_t *)dstUbAddr, strideConfig, mask);
        } else if constexpr (std::is_same_v<T, complex32>) {
            vsstb((RegTensor<int32_t> &)srcReg, (__ubuf__ int32_t *)dstUbAddr, strideConfig, mask);
        } else {
            vsstb(srcReg, dstUbAddr, strideConfig, mask);
        }
    }
}

/**
 * \brief Block-strided store of one vector register with vsstb, optionally post-updating dstUbAddr.
 * \tparam T        destination element type, deduced from dstUbAddr; must be DefaultType or RegT::ActualT.
 * \tparam dataMode forwarded to the non-post-update vsstb overload.
 * \tparam postMode POST_MODE_NORMAL delegates to the non-post-update overload (repeatStride is ignored);
 *                  any other value is forwarded to vsstb as a ::Post literal.
 * \param dstUbAddr       destination pointer, taken by reference so vsstb can update it.
 * \param srcReg          source register (RegTraitNumOne layout only).
 * \param dataBlockStride placed in the upper 16 bits of the vsstb stride operand.
 * \param repeatStride    truncated to 16 bits and placed in the lower 16 bits of the stride operand.
 * \param mask            store predicate.
 */
template <typename T = DefaultType, DataCopyMode dataMode, PostLiteral postMode, typename RegT>
__simd_callee__ inline void DataCopyImpl(__local_mem__ T *&dstUbAddr, RegT &srcReg, uint32_t dataBlockStride,
    uint32_t repeatStride, MaskReg &mask)
{
    if constexpr (postMode == PostLiteral::POST_MODE_NORMAL) {
        DataCopyImpl<T, dataMode, RegT>(dstUbAddr, srcReg, dataBlockStride, mask);
    } else {
        using ActualT = typename RegT::ActualT;
        static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
        static_assert(CheckRegTrait<RegT, RegTraitNumOne>(), "RegTensor only support RegTraitNumOne on current device!");
        constexpr auto postValue = std::integral_constant<::Post, static_cast<::Post>(postMode)>();
        // Stride operand: data block stride in bits [31:16], repeat stride in bits [15:0].
        const uint32_t strideConfig = (dataBlockStride << 16u) | (repeatStride & 0xFFFFU);
        if constexpr (SupportType<ActualT, fp4x2_e2m1_t, fp4x2_e1m2_t, hifloat8_t, fp8_e5m2_t, fp8_e4m3fn_t,
            int4x2_t>()) {
            vsstb((RegTensor<uint8_t> &)srcReg, (__ubuf__ uint8_t *&)dstUbAddr, strideConfig, mask, postValue);
        } else {
            static_assert(SupportBytes<ActualT, 1, 2, 4>(), "DataCopy only support type b8/b16/b32 on current device");
            if constexpr (std::is_same_v<T, bool>) {
                vsstb((RegTensor<int8_t> &)srcReg, (__ubuf__ int8_t *&)dstUbAddr, strideConfig, mask, postValue);
            } else if constexpr (std::is_same_v<T, complex32>) {
                vsstb((RegTensor<int32_t> &)srcReg, (__ubuf__ int32_t *&)dstUbAddr, strideConfig, mask, postValue);
            } else {
                vsstb(srcReg, dstUbAddr, strideConfig, mask, postValue);
            }
        }
    }
}

// vstus/vstas
/**
 * \brief Unaligned store of one vector register with vstus, using ureg as the unaligned-store register.
 * \tparam T        destination element type, deduced from dstUbAddr; must be DefaultType or RegT::ActualT.
 * \tparam postMode forwarded to vstus as a ::Post literal.
 * \param dstUbAddr        destination pointer, taken by reference so vstus can update it.
 * \param srcReg           source register (RegTraitNumOne or RegTraitNumTwo layout).
 * \param ureg             unaligned register passed through to vstus.
 * \param postUpdateStride forwarded to vstus; doubled where b64 / two-register complex32 data is written
 *                         through 32-bit / 16-bit views.
 */
template <typename T = DefaultType, PostLiteral postMode = PostLiteral::POST_MODE_UPDATE, typename RegT>
__simd_callee__ inline void DataCopyUnAlignImpl(__local_mem__ T *&dstUbAddr, RegT &srcReg, UnalignReg &ureg,
    uint32_t postUpdateStride)
{
    using ActualT = typename RegT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(SupportBytes<ActualT, 1, 2, 4, 8>(),
        "DataCopyUnAlign only support type b8/b16/b32/b64 on current device");
    // Same layout restriction as the aligned vsts overload. The b64 branch below handles only these two traits
    // and would otherwise compile to no store at all.
    static_assert(CheckRegTrait<RegT, RegTraitNumOne>() || CheckRegTrait<RegT, RegTraitNumTwo>(),
        "RegTensor only support RegTraitNumOne or RegTraitNumTwo on current device!");
    constexpr auto postValue = std::integral_constant<::Post, static_cast<::Post>(postMode)>();
    if constexpr (sizeof(ActualT) == 8) {
        if constexpr (CheckRegTrait<RegT, RegTraitNumOne>()) {
            vstus(ureg, postUpdateStride * 2, (RegTensor<uint32_t> &)srcReg, (__local_mem__ uint32_t *&)dstUbAddr,
                postValue);
        } else if constexpr (CheckRegTrait<RegT, RegTraitNumTwo>()) {
            RegTensor<uint32_t> tmp1;
            RegTensor<uint32_t> tmp2;
            Interleave(tmp1, tmp2, (RegTensor<uint32_t> &)srcReg.reg[0], (RegTensor<uint32_t> &)srcReg.reg[1]);
            // The first store covers at most one register's worth of ActualT elements; the second covers the rest.
            constexpr uint32_t one_repeat_num = VECTOR_REG_WIDTH / sizeof(ActualT);
            uint32_t tmpStride1 = (postUpdateStride > one_repeat_num) ? one_repeat_num : postUpdateStride;
            vstus(ureg, tmpStride1 * 2, tmp1, (__local_mem__ uint32_t *&)dstUbAddr, postValue);
            uint32_t tmpStride2 = (postUpdateStride > one_repeat_num) ? (postUpdateStride - one_repeat_num) : 0;
            vstus(ureg, tmpStride2 * 2, tmp2, (__local_mem__ uint32_t *&)dstUbAddr, postValue);
        }
    } else {
        if constexpr(SupportType<ActualT, complex32>() && (CheckRegTrait<RegT, RegTraitNumTwo>())) {
            RegTensor<uint16_t> tmp1;
            RegTensor<uint16_t> tmp2;
            Interleave(tmp1, tmp2, (RegTensor<uint16_t> &)srcReg.reg[0], (RegTensor<uint16_t> &)srcReg.reg[1]);
            constexpr uint32_t one_repeat_num = VECTOR_REG_WIDTH / sizeof(ActualT);
            uint32_t tmpStride1 = (postUpdateStride > one_repeat_num) ? one_repeat_num : postUpdateStride;
            vstus(ureg, tmpStride1 * 2, tmp1, (__local_mem__ uint16_t *&)dstUbAddr, postValue);
            uint32_t tmpStride2 = (postUpdateStride > one_repeat_num) ? (postUpdateStride - one_repeat_num) : 0;
            vstus(ureg, tmpStride2 * 2, tmp2, (__local_mem__ uint16_t *&)dstUbAddr, postValue);
        } else {
            if constexpr (std::is_same_v<T, bool>) {
                vstus(ureg, postUpdateStride, (RegTensor<int8_t> &)srcReg, (__ubuf__ int8_t *&)dstUbAddr, postValue);
            } else if constexpr (SupportBytes<ActualT, 4>()) {
                vstus(ureg, postUpdateStride, (RegTensor<int32_t> &)srcReg, (__ubuf__ int32_t *&)dstUbAddr, postValue);
            } else {
                vstus(ureg, postUpdateStride, srcReg, dstUbAddr, postValue);
            }
        }
    }
}

/**
 * \brief Issue vstas for the unaligned register ureg at dstUbAddr.
 * \tparam T        destination element type (b8/b16/b32/b64). bool goes through an int8_t view, 4-byte types
 *                  through int32_t, and 8-byte types through uint32_t with the stride doubled.
 * \tparam postMode POST_MODE_UPDATE selects the POST_UPDATE form of vstas; any other value uses the plain form.
 * \param dstUbAddr        destination pointer, taken by reference.
 * \param ureg             unaligned register passed through to vstas.
 * \param postUpdateStride stride forwarded to vstas.
 */
template <typename T, PostLiteral postMode = PostLiteral::POST_MODE_UPDATE>
__simd_callee__ inline void DataCopyUnAlignPostImpl(__local_mem__ T *&dstUbAddr, UnalignReg &ureg, int32_t postUpdateStride)
{
    static_assert(SupportBytes<T, 1, 2, 4, 8>(),
        "DataCopyUnAlignPost only support type b8/b16/b32/b64 on current device");
    constexpr bool withPostUpdate = (postMode == PostLiteral::POST_MODE_UPDATE);

    if constexpr (sizeof(T) == 8) {
        // b64 goes through a 32-bit view, so the stride is expressed in b32 units.
        const int32_t b32Stride = postUpdateStride * 2;
        if constexpr (withPostUpdate) {
            vstas(ureg, (__local_mem__ uint32_t *&)dstUbAddr, b32Stride, POST_UPDATE);
        } else {
            vstas(ureg, (__local_mem__ uint32_t *&)dstUbAddr, b32Stride);
        }
    } else if constexpr (std::is_same_v<T, bool>) {
        if constexpr (withPostUpdate) {
            vstas(ureg, (__ubuf__ int8_t *&)dstUbAddr, postUpdateStride, POST_UPDATE);
        } else {
            vstas(ureg, (__ubuf__ int8_t *&)dstUbAddr, postUpdateStride);
        }
    } else if constexpr (SupportBytes<T, 4>()) {
        if constexpr (withPostUpdate) {
            vstas(ureg, (__ubuf__ int32_t *&)dstUbAddr, postUpdateStride, POST_UPDATE);
        } else {
            vstas(ureg, (__ubuf__ int32_t *&)dstUbAddr, postUpdateStride);
        }
    } else {
        if constexpr (withPostUpdate) {
            vstas(ureg, dstUbAddr, postUpdateStride, POST_UPDATE);
        } else {
            vstas(ureg, dstUbAddr, postUpdateStride);
        }
    }
}

// vstu/vsta
/**
 * \brief Unaligned store of one vector register with vstu, addressed through areg.
 * \tparam T        destination element type, deduced from dstUbAddr; must be DefaultType or RegT::ActualT.
 * \tparam postMode forwarded to vstu as a ::Post literal.
 * \param dstUbAddr destination pointer, taken by reference.
 * \param srcReg    source register (RegTraitNumOne layout only).
 * \param ureg      unaligned register passed through to vstu.
 * \param areg      address register passed through to vstu.
 */
template <typename T = DefaultType, PostLiteral postMode = PostLiteral::POST_MODE_UPDATE, typename RegT>
__simd_callee__ inline void DataCopyUnAlignImpl(__local_mem__ T *&dstUbAddr, RegT &srcReg, UnalignReg &ureg, AddrReg &areg)
{
    using ActualT = typename RegT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(CheckRegTrait<RegT, RegTraitNumOne>(), "RegTensor only support RegTraitNumOne on current device!");
    static_assert(SupportBytes<ActualT, 1, 2, 4, 8>(), "only support type b8/b16/b32/b64 on current device");
    constexpr auto postValue = std::integral_constant<::Post, static_cast<::Post>(postMode)>();
    if constexpr (std::is_same_v<T, bool>) {
        vstu(ureg, areg, (RegTensor<int8_t> &)srcReg, (__ubuf__ int8_t *&)dstUbAddr, postValue);
    } else if constexpr (SupportBytes<T, 4, 8>()) {
        // b32 and b64 data are both stored through the int32_t register/pointer view.
        // NOTE(review): areg is forwarded unchanged for b64; confirm its offset unit does not need doubling.
        vstu(ureg, areg, (RegTensor<int32_t> &)srcReg, (__ubuf__ int32_t *&)dstUbAddr, postValue);
    } else {
        vstu(ureg, areg, srcReg, dstUbAddr, postValue);
    }
}

/**
 * \brief Issue vsta for the unaligned register ureg at dstUbAddr, addressed through areg.
 * \tparam T destination element type (b8/b16/b32/b64). 4- and 8-byte types go through an int32_t view,
 *           bool through an int8_t view, and everything else is passed as-is.
 * \param dstUbAddr destination pointer, taken by reference.
 * \param ureg      unaligned register passed through to vsta.
 * \param areg      address register passed through to vsta.
 */
template <typename T>
__simd_callee__ inline void DataCopyUnAlignPostImpl(__local_mem__ T *&dstUbAddr, UnalignReg &ureg, AddrReg &areg)
{
    static_assert(SupportBytes<T, 1, 2, 4, 8>(), "only support type b8/b16/b32/b64 on current device");
    if constexpr (SupportBytes<T, 4, 8>()) {
        vsta(ureg, (__ubuf__ int32_t *&)dstUbAddr, areg);
    } else if constexpr (std::is_same_v<T, bool>) {
        vsta(ureg, (__ubuf__ int8_t *&)dstUbAddr, areg);
    } else {
        vsta(ureg, dstUbAddr, areg);
    }
}

// vstur/vstar
/**
 * \brief Unaligned store of one vector register with vstur, using ureg as the unaligned-store register.
 * \tparam T        destination element type, deduced from dstUbAddr; must be DefaultType or RegT::ActualT.
 *                  8-byte types go through an int64_t view, 4-byte types through int32_t, bool through int8_t.
 * \tparam postMode forwarded to vstur as a ::Post literal.
 * \param dstUbAddr destination address.
 * \param srcReg    source register (RegTraitNumOne layout only).
 * \param ureg      unaligned register passed through to vstur.
 */
template <typename T = DefaultType, PostLiteral postMode = PostLiteral::POST_MODE_UPDATE, typename RegT>
__simd_callee__ inline void DataCopyUnAlignImpl(__local_mem__ T *dstUbAddr, RegT &srcReg, UnalignReg &ureg)
{
    using ActualT = typename RegT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(SupportBytes<ActualT, 1, 2, 4, 8>(), "only support type b8/b16/b32/b64 on current device");
    static_assert(CheckRegTrait<RegT, RegTraitNumOne>(), "RegTensor only support RegTraitNumOne on current device!");
    constexpr auto postTag = std::integral_constant<::Post, static_cast<::Post>(postMode)>();
    if constexpr (SupportBytes<T, 8>()) {
        vstur(ureg, (RegTensor<int64_t> &)srcReg, (__ubuf__ int64_t *)dstUbAddr, postTag);
    } else if constexpr (SupportBytes<T, 4>()) {
        vstur(ureg, (RegTensor<int32_t> &)srcReg, (__ubuf__ int32_t *)dstUbAddr, postTag);
    } else if constexpr (std::is_same_v<T, bool>) {
        vstur(ureg, (RegTensor<int8_t> &)srcReg, (__ubuf__ int8_t *)dstUbAddr, postTag);
    } else {
        vstur(ureg, srcReg, dstUbAddr, postTag);
    }
}

/**
 * \brief Issue vstar for the unaligned register ureg at dstUbAddr.
 * \tparam T destination element type (b8/b16/b32/b64). 8-byte types go through an int64_t view,
 *           4-byte types through int32_t, bool through int8_t, and everything else is passed as-is.
 * \param dstUbAddr destination address.
 * \param ureg      unaligned register passed through to vstar.
 */
template <typename T>
__simd_callee__ inline void DataCopyUnAlignPostImpl(__local_mem__ T *dstUbAddr, UnalignReg &ureg)
{
    static_assert(SupportBytes<T, 1, 2, 4, 8>(), "only support type b8/b16/b32/b64 on current device");
    if constexpr (SupportBytes<T, 8>()) {
        vstar(ureg, (__ubuf__ int64_t *)dstUbAddr);
    } else if constexpr (SupportBytes<T, 4>()) {
        vstar(ureg, (__ubuf__ int32_t *)dstUbAddr);
    } else if constexpr (std::is_same_v<T, bool>) {
        vstar(ureg, (__ubuf__ int8_t *)dstUbAddr);
    } else {
        vstar(ureg, dstUbAddr);
    }
}
} // namespace MicroAPI
} // namespace AscendC
#endif // ASCENDC_MODULE_MICRO_DATACOPY_STORE_IMPL_H
